{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffdeff6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "# ^^^ pyforest auto-imports - don't write above this line\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e3bb683",
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch.nn.parallel import DistributedDataParallel as DDP, DataParallel as DP\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from dataset import LUDB\n",
    "from cfg import TrainCfg, ModelCfg\n",
    "from trainer import LUDBTrainer\n",
    "from model import ECG_UNET_LUDB\n",
    "from metrics import compute_metrics\n",
    "\n",
    "from torch_ecg.utils import mask_to_intervals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4a660d3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51af3dea",
   "metadata": {},
   "outputs": [],
   "source": [
    "TrainCfg.db_dir = \"/home/wenhao/Jupyter/wenhao/data/LUDB/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aadd22b7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14bc6c4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two independent copies of the base training configuration that differ only in the loss name.\n",
    "train_cfg_fl, train_cfg_ce = deepcopy(TrainCfg), deepcopy(TrainCfg)\n",
    "for cfg_copy, loss_name in ((train_cfg_fl, \"FocalLoss\"), (train_cfg_ce, \"CrossEntropyLoss\")):\n",
    "    cfg_copy.use_single_lead = False\n",
    "    cfg_copy.loss = loss_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c135790f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01f87d4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_train_fl = LUDB(train_cfg_fl, training=True, lazy=False)\n",
    "# ds_train_ce = LUDB(train_cfg_ce, training=True, lazy=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "656a815c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ds_train_fl._load_all_data()\n",
    "# ds_train_ce._load_all_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e04170d",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_val_fl = LUDB(train_cfg_fl, training=False, lazy=False)\n",
    "# ds_val_ce = LUDB(train_cfg_ce, training=False, lazy=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da72d679",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ds_val_fl._load_all_data()\n",
    "# ds_val_ce._load_all_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2afd45d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_cfg_fl.keep_checkpoint_max = 0\n",
    "# train_cfg_fl.monitor = None\n",
    "# train_cfg_fl.n_epochs = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75b2c1d0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "400b052b",
   "metadata": {},
   "source": [
    "## dry run: no augmentation, no preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cba2c45",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work on a private copy so later edits never touch the shared ModelCfg object.\n",
    "model_config = deepcopy(ModelCfg)\n",
    "\n",
    "# Use CUDA when torch reports it as available, otherwise stay on the CPU.\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cdc8aa2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6d4bc98",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model = ECG_UNET_LUDB(model_config.n_leads, model_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1981438a",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.module_size_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f251708",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the model in DataParallel only when at least two CUDA devices are visible.\n",
    "if torch.cuda.device_count() >= 2:\n",
    "    model = DP(model)\n",
    "# Last expression on purpose: moves the model to the chosen device and displays it.\n",
    "model.to(device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f981ac5d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed900036",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "trainer = LUDBTrainer(\n",
    "    model=model,\n",
    "    model_config=model_config,\n",
    "    train_config=train_cfg_fl,\n",
    "    device=device,\n",
    "    lazy=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67313494",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer._setup_dataloaders(ds_train_fl, ds_val_fl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b657c519",
   "metadata": {},
   "outputs": [],
   "source": [
    "bmd = trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bc106c5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b39d757",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09fb1157",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "46fd7d19",
   "metadata": {},
   "source": [
    "## eval and plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93aad572",
   "metadata": {},
   "outputs": [],
   "source": [
    "model, _ = ECG_UNET_LUDB.from_checkpoint(\"/home/wenhao/.cache/torch_ecg/saved_models/BestModel_ECG_UNET_LUDB_epoch136_09-08_10-36_metric_0.97.pth.tar\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9d1f933",
   "metadata": {},
   "outputs": [],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d37561c0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9025759a",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_output = model.inference(ds_val_fl.signals[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6665ca52",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_output.mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d3a6626",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick visual check: one signal row in black, the predicted mask in red on a twin y-axis.\n",
    "signal_row = ds_val_fl.signals[0][0]\n",
    "predicted_row = model_output.mask[0]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(20, 6))\n",
    "ax.plot(signal_row, color=\"black\")\n",
    "ax2 = ax.twinx()\n",
    "ax2.plot(predicted_row, color=\"red\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d961d324",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trace drawn in the figures below: row 2 of the first validation signal array.\n",
    "values = ds_val_fl.signals[0][2]\n",
    "# np.where(...)[1] keeps, for every entry equal to 1, its index along the second axis.\n",
    "# NOTE(review): labels are read from row 0 while the trace above uses row 2 -- confirm this is intended.\n",
    "label_array = ds_val_fl.labels[0][0]\n",
    "mask_labels = np.where(label_array == 1)[1]\n",
    "mask_preds = model_output.mask[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3daad0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class code found in the masks -> name of the wave it denotes (1: P wave, 2: QRS, 3: T wave).\n",
    "mapping = dict(enumerate((\"P wave\", \"QRS\", \"T wave\"), start=1))\n",
    "\n",
    "# colors = ['#4477AA', '#EE6677', '#228833', '#CCBB44', '#66CCEE', '#AA3377', '#BBBBBB']  # bright\n",
    "colors = [\"#004488\", \"#DDAA33\", \"#BB5566\"]  # high-contrast\n",
    "# P wave takes the first colour, QRS the third and T wave the second.\n",
    "pallete = dict(zip(mapping.values(), (colors[0], colors[2], colors[1])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bec320d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn each mask into a {wave name: spans} dictionary; class codes 1-3 are renamed through `mapping`.\n",
    "intervals_preds = {\n",
    "    mapping[code]: spans\n",
    "    for code, spans in mask_to_intervals(mask_preds, vals=[1, 2, 3]).items()\n",
    "}\n",
    "intervals_labels = {\n",
    "    mapping[code]: spans\n",
    "    for code, spans in mask_to_intervals(mask_labels, vals=[1, 2, 3]).items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85aaceba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Font sizes for tick labels, axis labels and legend entries of the figures that follow.\n",
    "# (A larger variant previously tried here used 36 / 36 / 50 / 40.)\n",
    "plt.rcParams.update(\n",
    "    {\n",
    "        \"xtick.labelsize\": 24,\n",
    "        \"ytick.labelsize\": 24,\n",
    "        \"axes.labelsize\": 32,\n",
    "        \"legend.fontsize\": 24,\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbded493",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.patches import Patch\n",
    "\n",
    "# Compact figure: the signal trace with predicted wave spans shaded in the upper part of\n",
    "# the axes and labelled wave spans shaded in the lower part.\n",
    "sns.set_style(\"dark\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(20, 8))\n",
    "fs = ds_val_fl.reader.fs\n",
    "split_y = 0.35  # axes-fraction height separating the label band (below) from the prediction band (above)\n",
    "\n",
    "# Sample indices are divided by fs, matching the \"Time (s)\" axis label set further down.\n",
    "ax.plot(np.arange(len(values)) / fs, values, color=\"black\", lw=1.2)\n",
    "ax.set(xlim=(-150 / fs, 5150 / fs), ylim=(-0.5, 0.9))\n",
    "\n",
    "# Predictions are drawn first (above split_y), then labels (below split_y); the 0.02\n",
    "# offsets leave a thin gap between the two bands.\n",
    "for spans_by_wave, band_kwargs in (\n",
    "    (intervals_preds, {\"ymin\": split_y + 0.02, \"alpha\": 0.4}),\n",
    "    (intervals_labels, {\"ymax\": split_y - 0.02, \"alpha\": 0.6}),\n",
    "):\n",
    "    for wave, spans in spans_by_wave.items():\n",
    "        for span in spans:\n",
    "            ax.axvspan(span[0] / fs, span[1] / fs, color=pallete[wave], **band_kwargs)\n",
    "\n",
    "# Captions; the two rotated mask captions sit to the right of the x-limit set above.\n",
    "ax.text(-110 / fs, 0.8, \"Lead III\", fontsize=28)\n",
    "for caption, y_pos in ((\"Label Mask\", -0.48), (\"Predicted Mask\", 0.15)):\n",
    "    ax.text(5300 / fs, y_pos, caption, fontsize=28, rotation=90)\n",
    "ax.set_xlabel(\"Time (s)\")\n",
    "ax.set_ylabel(\"Voltage (mV)\")\n",
    "\n",
    "# One legend patch per wave type, laid out in a single row just above the axes.\n",
    "legend_elements = [\n",
    "    Patch(facecolor=wave_color, label=wave_name, alpha=0.5)\n",
    "    for wave_name, wave_color in pallete.items()\n",
    "]\n",
    "ax.legend(\n",
    "    handles=legend_elements,\n",
    "    loc=\"lower center\",\n",
    "    bbox_to_anchor=(0.5, 0.99),\n",
    "    ncol=len(pallete),\n",
    "    fancybox=True,\n",
    ")\n",
    "ax.set_xticks(np.arange(0, 10.5, 0.5))\n",
    "ax.grid(which=\"major\", linewidth=\"1\", color=\"white\");\n",
    "\n",
    "# plt.savefig(\"./images/ludb-unet-val-example-small.pdf\", dpi=1200, bbox_inches=\"tight\", transparent=False);\n",
    "# plt.savefig(\"./images/ludb-unet-val-example-small.svg\", dpi=1200, bbox_inches=\"tight\", transparent=False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a97a7a79",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dc259f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.patches import Patch\n",
    "\n",
    "# Wide figure: the signal trace on a two-level (major/minor) grid, with predicted wave\n",
    "# spans shaded in the upper part of the axes and labelled wave spans in the lower part.\n",
    "fig, ax = plt.subplots(figsize=(120, 12))\n",
    "fs = ds_val_fl.reader.fs\n",
    "split_y = 0.35  # axes-fraction height separating the label band (below) from the prediction band (above)\n",
    "\n",
    "# Sample indices are divided by fs, matching the \"Time (s)\" axis label set further down.\n",
    "ax.plot(np.arange(len(values)) / fs, values, color=\"black\", lw=1.2)\n",
    "ax.set(xlim=(-150 / fs, 5150 / fs), ylim=(-0.6, 1.6))\n",
    "\n",
    "# Major grid: ticks every 0.2 on x and 0.5 on y, solid red lines.\n",
    "ax.xaxis.set_major_locator(plt.MultipleLocator(0.2))\n",
    "ax.yaxis.set_major_locator(plt.MultipleLocator(0.5))\n",
    "ax.grid(which=\"major\", linestyle=\"-\", linewidth=\"0.4\", color=\"red\")\n",
    "# Minor grid: ticks every 0.04 on x and 0.1 on y, dotted gray lines.\n",
    "ax.xaxis.set_minor_locator(plt.MultipleLocator(0.04))\n",
    "ax.yaxis.set_minor_locator(plt.MultipleLocator(0.1))\n",
    "ax.grid(which=\"minor\", linestyle=\":\", linewidth=\"0.2\", color=\"gray\")\n",
    "\n",
    "# Predictions are drawn first (above split_y), then labels (below split_y); the 0.02\n",
    "# offsets leave a thin gap between the two bands.\n",
    "for spans_by_wave, band_kwargs in (\n",
    "    (intervals_preds, {\"ymin\": split_y + 0.02, \"alpha\": 0.4}),\n",
    "    (intervals_labels, {\"ymax\": split_y - 0.02, \"alpha\": 0.6}),\n",
    "):\n",
    "    for wave, spans in spans_by_wave.items():\n",
    "        for span in spans:\n",
    "            ax.axvspan(span[0] / fs, span[1] / fs, color=pallete[wave], **band_kwargs)\n",
    "\n",
    "# Captions; the two rotated mask captions sit to the right of the x-limit set above.\n",
    "ax.text(-110 / fs, 1.2, \"Lead III\", fontsize=28)\n",
    "for caption, y_pos in ((\"Label Mask\", -0.55), (\"Predicted Mask\", 0.35)):\n",
    "    ax.text(5200 / fs, y_pos, caption, fontsize=28, rotation=90)\n",
    "ax.set_xlabel(\"Time (s)\")\n",
    "ax.set_ylabel(\"Voltage (mV)\")\n",
    "\n",
    "# One legend patch per wave type.\n",
    "legend_elements = [\n",
    "    Patch(facecolor=wave_color, label=wave_name, alpha=0.5)\n",
    "    for wave_name, wave_color in pallete.items()\n",
    "]\n",
    "ax.legend(handles=legend_elements, loc=\"lower left\", fancybox=True);\n",
    "\n",
    "# plt.savefig(\"./images/ludb-unet-val-example-large.pdf\", dpi=1200, bbox_inches=\"tight\", transparent=False);\n",
    "# plt.savefig(\"./images/ludb-unet-val-example-large.svg\", dpi=1200, bbox_inches=\"tight\", transparent=False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8345624e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0353a030",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "91a4e915",
   "metadata": {},
   "source": [
    "## gather results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e84f895e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "from matplotlib.pyplot import cm\n",
    "\n",
    "# Apply seaborn's default theme before reading the colour cycle and setting font sizes.\n",
    "sns.set()\n",
    "# Colours of the active matplotlib property cycle, used for the curves plotted below.\n",
    "colors = plt.rcParams[\"axes.prop_cycle\"].by_key()[\"color\"]\n",
    "# Raw string keeps the mathtext backslash literal (\"\\h\" is not a valid escape in a normal string).\n",
    "markers = [\"*\", \"v\", \"p\", \"d\", \"s\", r\"$\\heartsuit$\", \"+\", \"x\"]\n",
    "marker_size = 9\n",
    "\n",
    "# Font sizes for tick labels, axis labels and legend entries.\n",
    "plt.rcParams[\"xtick.labelsize\"] = 24\n",
    "plt.rcParams[\"ytick.labelsize\"] = 24\n",
    "plt.rcParams[\"axes.labelsize\"] = 32\n",
    "plt.rcParams[\"legend.fontsize\"] = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e65478d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_res = pd.read_csv(\"/home/wenhao/Jupyter/wenhao/workspace/torch_ecg/benchmarks/train_unet_ludb/log/TorchECG_04-06_22-30_ECG_UNET_LUDB_adamw_amsgrad_LR_0.001_BS_32.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e7c38f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cde53e02",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Validation F1 score (left axis) and training loss (right axis) plotted against epoch.\n",
    "fig, ax = plt.subplots(figsize=(16, 12))\n",
    "\n",
    "line_width = 2.5\n",
    "curve_style = {\"linewidth\": line_width, \"markersize\": marker_size}\n",
    "\n",
    "# Split the log on its \"part\" column; training rows without a loss value are dropped.\n",
    "df_val = df_res.loc[df_res[\"part\"] == \"val\"]\n",
    "df_train = df_res.loc[df_res[\"part\"] == \"train\"].dropna(subset=[\"loss\"])\n",
    "\n",
    "lns1 = ax.plot(\n",
    "    df_val[\"epoch\"],\n",
    "    df_val[\"f1_score\"],\n",
    "    marker=markers[0],\n",
    "    color=colors[0],\n",
    "    label=\"val-f1-score\",\n",
    "    **curve_style,\n",
    ")\n",
    "ax.set_xlabel(\"Epochs (n.u.)\", fontsize=36)\n",
    "ax.set_ylabel(\"f1 score (n.u.)\", fontsize=36)\n",
    "ax.set_ylim(-0.1, 1)\n",
    "\n",
    "# The loss curve gets its own y-axis on the right-hand side.\n",
    "ax2 = ax.twinx()\n",
    "lns2 = ax2.plot(\n",
    "    df_train[\"epoch\"],\n",
    "    df_train[\"loss\"],\n",
    "    marker=markers[1],\n",
    "    color=colors[1],\n",
    "    label=\"train-loss\",\n",
    "    **curve_style,\n",
    ")\n",
    "ax2.set_ylabel(\"Loss (n.u.)\", fontsize=36)\n",
    "ax2.set_ylim(-0.03, 0.3)\n",
    "ax2.set_yticks(np.arange(0, 0.35, 0.06))\n",
    "\n",
    "# Merge the handles of both axes so a single legend lists the two curves.\n",
    "lns = lns1 + lns2\n",
    "labs = [curve.get_label() for curve in lns]\n",
    "ax.legend(lns, labs, loc=\"lower right\", fontsize=26)\n",
    "\n",
    "# plt.savefig(\"./results/ludb-unet-score-loss.pdf\", dpi=1200, bbox_inches=\"tight\", transparent=False);\n",
    "# plt.savefig(\"./results/ludb-unet-score-loss.svg\", dpi=1200, bbox_inches=\"tight\", transparent=False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fde67e14",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97789c4e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
